/*
 * Copyright 2024 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.google.tsunami.plugins.rce.flyteconsole;

import com.google.common.flogger.GoogleLogger;
import com.google.common.util.concurrent.Uninterruptibles;
import com.google.protobuf.Duration;
import com.google.protobuf.Struct;
import flyteidl.admin.ExecutionOuterClass.Execution;
import flyteidl.admin.ExecutionOuterClass.ExecutionCreateRequest;
import flyteidl.admin.ExecutionOuterClass.ExecutionCreateResponse;
import flyteidl.admin.ExecutionOuterClass.ExecutionMetadata;
import flyteidl.admin.ExecutionOuterClass.ExecutionMetadata.ExecutionMode;
import flyteidl.admin.ExecutionOuterClass.ExecutionSpec;
import flyteidl.admin.ExecutionOuterClass.WorkflowExecutionGetRequest;
import flyteidl.admin.ProjectOuterClass;
import flyteidl.admin.ProjectOuterClass.Project;
import flyteidl.admin.ProjectOuterClass.ProjectListRequest;
import flyteidl.admin.ProjectOuterClass.ProjectRegisterRequest;
import flyteidl.admin.ProjectOuterClass.ProjectRegisterResponse;
import flyteidl.admin.ProjectOuterClass.Projects;
import flyteidl.admin.TaskOuterClass.TaskCreateRequest;
import flyteidl.admin.TaskOuterClass.TaskCreateResponse;
import flyteidl.admin.TaskOuterClass.TaskSpec;
import flyteidl.core.Execution.WorkflowExecution.Phase;
import flyteidl.core.IdentifierOuterClass.Identifier;
import flyteidl.core.IdentifierOuterClass.ResourceType;
import flyteidl.core.IdentifierOuterClass.WorkflowExecutionIdentifier;
import flyteidl.core.Interface.TypedInterface;
import flyteidl.core.Interface.VariableMap;
import flyteidl.core.Literals;
import flyteidl.core.Literals.Literal;
import flyteidl.core.Literals.Primitive;
import flyteidl.core.Literals.RetryStrategy;
import flyteidl.core.Literals.Scalar;
import flyteidl.core.Tasks;
import flyteidl.core.Tasks.Container;
import flyteidl.core.Tasks.DataLoadingConfig;
import flyteidl.core.Tasks.RuntimeMetadata;
import flyteidl.core.Tasks.RuntimeMetadata.RuntimeType;
import flyteidl.core.Tasks.TaskMetadata;
import flyteidl.core.Tasks.TaskTemplate;
import flyteidl.service.AdminServiceGrpc;
import flyteidl.service.AdminServiceGrpc.AdminServiceBlockingStub;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.StatusRuntimeException;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.Arrays;

/**
 * FlyteProtoClient is a gRPC client for the Flyte Admin service.
 *
 * <p>It can list and register projects, register a container task that runs a shell script,
 * launch an execution of that task, and poll the execution's phase until it stops running.
 */
public class FlyteProtoClient {

  private static final GoogleLogger logger = GoogleLogger.forEnclosingClass();
  private static final String MY_TASK_TYPE = "container";

  // Container image in which the shell script is executed.
  private static final String CONTAINER_NAME = "docker.io/nginx:latest";
  private static final String INPUT_PATH = "/tmp";
  private static final String OUT_PATH = "/tmp";
  private static final String SHELL_PATH = "sh";
  private static final String DOMAIN = "development";
  // Polling interval while waiting for a launched execution to stop running.
  private static final int WAIT_TIME_SECS_FOR_EXECUTION_STATUS = 5;
  // Delay between attempts to launch an execution of a freshly registered task.
  private static final int WAIT_TIME_SECS_FOR_EXECUTION_ID = 20;

  private static final String PROJECT_NAME = "flytesnacks";
  private static final String TASK_NAME = "tsunamirce";
  // Upper bound on how long runShellScript keeps retrying to launch the task.
  private static final int TASK_EXECUTION_TIMEOUT_SECS = 180;
  // Timeout set in the registered task's metadata for a single run of the task.
  private static final int TASK_RUN_TIMEOUT_SECS = 180;

  // Ports used when the server URL does not specify one.
  private static final int DEFAULT_HTTPS_PORT = 443;
  private static final int DEFAULT_HTTP_PORT = 80;

  // Stub generated by gRPC that allows remote procedure calls (RPC) to the Flyte
  // Admin service.
  AdminServiceBlockingStub flyteService;

  FlyteProtoClient() {}

  /**
   * Sets the gRPC stub for interacting with the Flyte Admin service.
   *
   * @param stub The AdminServiceBlockingStub instance to be set.
   */
  public void setStub(AdminServiceBlockingStub stub) {
    this.flyteService = stub;
  }

  /**
   * Establishes a connection to the Flyte server and initializes the gRPC service stub.
   *
   * <p>A TLS channel is used when the URL scheme is {@code https} (case-insensitive); otherwise a
   * plaintext channel is used. If the URL has no explicit port, the scheme's default port (443
   * for https, 80 otherwise) is used.
   *
   * @param url The URL of the Flyte server to connect to, e.g. "http://host:port".
   * @throws URISyntaxException if {@code url} cannot be parsed or does not contain a host.
   */
  public void buildService(String url) throws URISyntaxException {
    URI uri = new URI(url);
    String host = uri.getHost();
    if (host == null) {
      throw new URISyntaxException(url, "URL must contain a host, e.g. http://host:port");
    }

    // Schemes are case-insensitive; compare by value, never by reference.
    boolean useTls = "https".equalsIgnoreCase(uri.getScheme());
    int port = uri.getPort();
    if (port == -1) {
      // URI.getPort() returns -1 when the URL has no port component.
      port = useTls ? DEFAULT_HTTPS_PORT : DEFAULT_HTTP_PORT;
    }

    // Managed channel for establishing a connection to the Flyte server.
    ManagedChannelBuilder<?> channelBuilder =
        ManagedChannelBuilder.forTarget(String.format("%s:%d", host, port));
    if (useTls) {
      channelBuilder.useTransportSecurity();
    } else {
      channelBuilder.usePlaintext();
    }
    ManagedChannel channel = channelBuilder.enableRetry().build();

    this.flyteService = AdminServiceGrpc.newBlockingStub(channel);
  }

  /** Wraps a primitive value into a scalar Flyte literal. */
  private static Literal asLiteral(Literals.Primitive primitive) {
    Scalar scalar = Scalar.newBuilder().setPrimitive(primitive).build();
    return Literal.newBuilder().setScalar(scalar).build();
  }

  /**
   * Builds a Flyte literal holding a string value.
   *
   * @param value The string to wrap.
   * @return A scalar literal whose primitive is {@code value}.
   */
  public static Literals.Literal ofString(String value) {
    Primitive primitive = Primitive.newBuilder().setStringValue(value).build();
    return asLiteral(primitive);
  }

  /**
   * Repeatedly tries to launch an execution of a task until an execution ID is obtained.
   *
   * <p>At least one attempt is always made. Attempts are spaced
   * WAIT_TIME_SECS_FOR_EXECUTION_ID seconds apart, and there is no sleep after the final attempt.
   *
   * @param project The project in which the task is to be run.
   * @param taskName The name of the task.
   * @param taskVersion The version of the task.
   * @param maxTimeOutInSecs The maximum time to keep retrying, in seconds.
   * @return The execution ID if one was obtained within the timeout period; otherwise null.
   */
  public String waitForTheExecutionId(
      String project, String taskName, String taskVersion, int maxTimeOutInSecs) {
    int attempts = Math.max(1, maxTimeOutInSecs / WAIT_TIME_SECS_FOR_EXECUTION_ID);

    for (int attempt = 1; attempt <= attempts; attempt++) {
      String executionId = this.runTask(project, taskName, taskVersion);
      if (executionId != null) {
        return executionId;
      }
      if (attempt < attempts) {
        logger.atFine().log("Unable to run the task in flyte, retrying");
        Uninterruptibles.sleepUninterruptibly(
            java.time.Duration.ofSeconds(WAIT_TIME_SECS_FOR_EXECUTION_ID));
      }
    }
    return null;
  }

  /**
   * Waits for a Flyte execution to stop running.
   *
   * <p>The execution's phase is checked every WAIT_TIME_SECS_FOR_EXECUTION_STATUS seconds, up to
   * {@code maxTimeOutInSecs / WAIT_TIME_SECS_FOR_EXECUTION_STATUS} times. The method returns early
   * once the execution is no longer running, whether it succeeded or not. A warning is logged if
   * it is still running when the time budget is used up.
   *
   * @param project The name of the project in Flyte where the task is executed.
   * @param executionId The unique identifier of the task execution.
   * @param maxTimeOutInSecs The maximum time (in seconds) to wait for the task to finish.
   */
  public void waitForTheScriptToFinish(String project, String executionId, int maxTimeOutInSecs) {
    int checks = maxTimeOutInSecs / WAIT_TIME_SECS_FOR_EXECUTION_STATUS;
    for (int i = 0; i < checks; i++) {
      if (!this.checkExecutionRunning(project, executionId)) {
        // Not running any more: the phase may be SUCCEEDED, FAILED, ABORTED, TIMED_OUT, etc.
        logger.atInfo().log("Task execution finished in the Flyte Console");
        return;
      }
      // Sleeping after the last check too gives a still-running script one more interval
      // before the caller continues.
      Uninterruptibles.sleepUninterruptibly(
          java.time.Duration.ofSeconds(WAIT_TIME_SECS_FOR_EXECUTION_STATUS));
    }
    logger.atWarning().log(
        "Flyte execution %s did not finish within %d seconds", executionId, maxTimeOutInSecs);
  }

  /**
   * Executes a shell script through the Flyte Console.
   *
   * <p>Ensures the project exists, creating it if needed. Registers a container task that runs
   * the script, launches it, and waits for the execution to stop running. Failures are logged,
   * not thrown.
   *
   * @param shellScript The shell script to be executed.
   * @param maxTimeout The maximum time in seconds to wait for the script to complete.
   */
  public void runShellScript(String shellScript, int maxTimeout) {
    // Millisecond timestamp, kept as a long: it cannot overflow in 2038. It also keeps versions
    // distinct for runs started within the same second.
    String taskVersion = String.format("v%d", System.currentTimeMillis());

    // 1. Make sure the project exists, creating it if necessary.
    if (!this.checkProjectExists(PROJECT_NAME) && !this.createProject(PROJECT_NAME)) {
      logger.atSevere().log(
          "Unable to create project in Flyte Console with name : %s", PROJECT_NAME);
      return;
    }

    // 2. Register a task that runs the provided shell script.
    if (!this.createTask(PROJECT_NAME, TASK_NAME, taskVersion, shellScript)) {
      logger.atSevere().log(
          "Unable to create task in Flyte Console with project: %s , task:%s, version:%s",
          PROJECT_NAME, TASK_NAME, taskVersion);
      return;
    }

    // 3. Launch the task. A freshly created project or task may take a while before it can be
    // executed, so launching is retried for up to TASK_EXECUTION_TIMEOUT_SECS.
    String executionId =
        this.waitForTheExecutionId(
            PROJECT_NAME, TASK_NAME, taskVersion, TASK_EXECUTION_TIMEOUT_SECS);
    if (executionId == null) {
      logger.atSevere().log(
          "Unable to run task in Flyte Console with project: %s , task:%s, version:%s",
          PROJECT_NAME, TASK_NAME, taskVersion);
      return;
    }

    // 4. Wait for the script to finish executing.
    this.waitForTheScriptToFinish(PROJECT_NAME, executionId, maxTimeout);
  }

  /**
   * Checks whether a project with the given ID exists in the Flyte Console.
   *
   * <p>Sends a single {@code listProjects} request with default parameters and searches the
   * returned projects for a matching ID. gRPC errors propagate to the caller.
   *
   * @param projectId The ID of the project to look for.
   * @return true if a project with that ID was returned, false otherwise.
   */
  public boolean checkProjectExists(String projectId) {
    Projects projects = flyteService.listProjects(ProjectListRequest.newBuilder().build());
    for (Project project : projects.getProjectsList()) {
      if (project.getId().equals(projectId)) {
        return true;
      }
    }
    return false;
  }

  /**
   * Registers a new project in the Flyte Console and confirms that it exists.
   *
   * <p>The project's ID and name are both set to {@code projectName}. A registration failure is
   * logged but is not decisive: the result is always determined by a subsequent
   * {@link #checkProjectExists} call.
   *
   * @param projectName The name (and ID) of the project to be created.
   * @return true if the project exists after the registration attempt, false otherwise.
   */
  public boolean createProject(String projectName) {
    Project project = Project.newBuilder().setId(projectName).setName(projectName).build();
    ProjectRegisterRequest request =
        ProjectRegisterRequest.newBuilder().setProject(project).build();

    try {
      // Blocking gRPC stubs report failures by throwing, never by returning null.
      flyteService.registerProject(request);
    } catch (StatusRuntimeException e) {
      logger.atWarning().withCause(e).log(
          "Failed to register project %s in Flyte Console", projectName);
    }
    return this.checkProjectExists(projectName);
  }

  /**
   * Registers a new container task in the Flyte Console within a specified project.
   *
   * <p>The task runs {@code SHELL_PATH -c shellScript} inside the CONTAINER_NAME image, with one
   * retry and a TASK_RUN_TIMEOUT_SECS timeout.
   *
   * @param project The name of the project in which the task will be created.
   * @param taskName The name of the task to be created.
   * @param version The version of the task.
   * @param shellScript The shell script to be executed by the task.
   * @return true if the createTask RPC completed without throwing, false otherwise.
   */
  public boolean createTask(String project, String taskName, String version, String shellScript) {
    try {
      logger.atFine().log(
          "Creating task  with project=%s, task=%s, version=%s, shellScript=%s",
          project, taskName, version, shellScript);

      Identifier taskId =
          Identifier.newBuilder()
              .setResourceType(ResourceType.TASK)
              .setDomain(DOMAIN)
              .setProject(project)
              .setName(taskName)
              .setVersion(version)
              .build();

      // The task takes no inputs and produces no outputs.
      TypedInterface taskInterface =
          TypedInterface.newBuilder()
              .setInputs(VariableMap.newBuilder().build())
              .setOutputs(VariableMap.newBuilder().build())
              .build();

      RetryStrategy retryStrategy = RetryStrategy.newBuilder().setRetries(1).build();

      // Container spawned in the cluster to run the shell script.
      Container container =
          Container.newBuilder()
              .setImage(CONTAINER_NAME)
              .setDataConfig(
                  DataLoadingConfig.newBuilder()
                      .setInputPath(INPUT_PATH)
                      .setOutputPath(OUT_PATH)
                      .build())
              .addAllArgs(Arrays.asList("-c", shellScript))
              .addCommand(SHELL_PATH)
              .build();

      RuntimeMetadata runMetadata =
          RuntimeMetadata.newBuilder()
              .setType(RuntimeType.FLYTE_SDK)
              .setVersion("0.0.1")
              .setFlavor("java")
              .build();

      TaskMetadata metadata =
          TaskMetadata.newBuilder()
              .setDiscoverable(false)
              .setCacheSerializable(false)
              .setTimeout(Duration.newBuilder().setSeconds(TASK_RUN_TIMEOUT_SECS).build())
              .setRetries(retryStrategy)
              .setRuntime(runMetadata)
              .build();

      Tasks.TaskTemplate taskTemplate =
          TaskTemplate.newBuilder()
              .setType(MY_TASK_TYPE)
              .setInterface(taskInterface)
              .setContainer(container)
              .setMetadata(metadata)
              .setCustom(Struct.newBuilder().build())
              .build();

      TaskCreateRequest request =
          TaskCreateRequest.newBuilder()
              .setId(taskId)
              .setSpec(TaskSpec.newBuilder().setTemplate(taskTemplate).build())
              .build();

      // Blocking gRPC stubs throw on failure, so reaching the next line means success.
      flyteService.createTask(request);
      return true;
    } catch (Exception e) {
      logger.atSevere().withCause(e).log(
          "Exception while creating task with project: %s , task:%s, version:%s",
          project, taskName, version);
    }

    return false;
  }

  /**
   * Launches an execution of a registered task and returns the execution ID.
   *
   * <p>Errors are logged at fine level, not thrown, because callers retry this method.
   *
   * @param project The name of the project where the task is to be run.
   * @param taskName The name of the task to be executed.
   * @param version The version of the task to be executed.
   * @return The execution ID if the execution was created; otherwise null.
   */
  public String runTask(String project, String taskName, String version) {
    try {
      logger.atFine().log(
          "Running Task  with project=%s, task=%s, version=%s,", project, taskName, version);

      // The launch target is the task itself (resource type TASK), not a launch plan.
      Identifier launchId =
          Identifier.newBuilder()
              .setResourceType(ResourceType.TASK)
              .setProject(project)
              .setDomain(DOMAIN)
              .setName(taskName)
              .setVersion(version)
              .build();

      ExecutionMetadata metadata =
          ExecutionMetadata.newBuilder()
              .setMode(ExecutionMode.MANUAL)
              .setPrincipal("flyteconsole")
              .build();

      ExecutionSpec executionSpec =
          ExecutionSpec.newBuilder().setMetadata(metadata).setLaunchPlan(launchId).build();

      ExecutionCreateRequest request =
          ExecutionCreateRequest.newBuilder()
              .setDomain(DOMAIN)
              .setProject(project)
              .setSpec(executionSpec)
              .build();
      ExecutionCreateResponse response = flyteService.createExecution(request);
      if (response.hasId()) {
        String executionId = response.getId().getName();
        logger.atInfo().log("Execution created with ID %s", executionId);
        return executionId;
      }
    } catch (Exception e) {
      logger.atFine().withCause(e).log(
          "Exception while running task with project: %s , task:%s, version:%s",
          project, taskName, version);
    }
    return null;
  }

  /**
   * Reports whether an execution phase should be treated as still running.
   *
   * @param phase The workflow execution phase reported by Flyte.
   * @return true for UNDEFINED, QUEUED, RUNNING and SUCCEEDING; false for every other phase.
   */
  private boolean isRunning(Phase phase) {
    switch (phase) {
      case UNDEFINED:
      case QUEUED:
      case RUNNING:
      case SUCCEEDING:
        return true;
      default:
        // SUCCEEDED, FAILED, TIMED_OUT, ABORTED, the FAILING/ABORTING phases and UNRECOGNIZED
        // values all end the wait.
        return false;
    }
  }

  /**
   * Checks whether a specific workflow execution is still running in the Flyte Console.
   *
   * <p>Fetches the execution identified by project, the development domain and execution ID. If
   * the response carries a closure, its phase decides the result. gRPC errors propagate to the
   * caller.
   *
   * @param project The name of the project in which the workflow is executed.
   * @param executionId The unique identifier of the workflow execution.
   * @return true if the execution is still running, false otherwise (including when the response
   *     has no closure).
   */
  public boolean checkExecutionRunning(String project, String executionId) {
    WorkflowExecutionGetRequest request =
        WorkflowExecutionGetRequest.newBuilder()
            .setId(
                WorkflowExecutionIdentifier.newBuilder()
                    .setName(executionId)
                    .setDomain(DOMAIN)
                    .setProject(project)
                    .build())
            .build();

    Execution resp = flyteService.getExecution(request);
    // Without a closure there is no phase to inspect; treat the execution as not running.
    if (resp.hasClosure()) {
      return this.isRunning(resp.getClosure().getPhase());
    }

    return false;
  }
}
